# -*-Python-*-

import mesh_tensorflow.transformer.fixup_layers

# Sublayer sequence bound for each layer of a LayerStack; the call to the layer
# itself is bracketed by the Fixup shift (before) and Fixup scale (after) entries.
transformer.LayerStack.sublayers_per_layer = [
    @transformer.sublayer_mask_padding,
    @fixup_layers.sublayer_fixup_shift,
    @transformer.sublayer_call_layer,
    @transformer.sublayer_dropout,
    @fixup_layers.sublayer_fixup_scale,
    @transformer.sublayer_residual,
]


# Encoder final sublayers: dropout and padding mask, followed by both the Fixup
# shift and the Fixup scale.
encoder/transformer.LayerStack.sublayers_final = [
    @transformer.sublayer_dropout,
    @transformer.sublayer_mask_padding,
    @fixup_layers.sublayer_fixup_shift,
    @fixup_layers.sublayer_fixup_scale,
]

# Decoder final sublayers: dropout and padding mask, followed by the Fixup shift
# only (no Fixup scale entry, unlike the encoder).
decoder/transformer.LayerStack.sublayers_final = [
    @transformer.sublayer_dropout,
    @transformer.sublayer_mask_padding,
    @fixup_layers.sublayer_fixup_shift,
]

# Bind DenseReluDenseFixup's use_bias parameter to False.
fixup_layers.DenseReluDenseFixup.use_bias = False

# Overrides the layer lists from
# third_party/py/mesh_tensorflow/transformer/gin/defaults.gin
# so that the encoder and decoder stacks are built from the Fixup layer classes.
# Each encoder layer group has 2 entries and each decoder layer group has 3.
encoder/transformer.make_layer_stack.layers = [
    @fixup_layers.SelfAttentionFixup,
    @fixup_layers.DenseReluDenseFixup,
]
decoder/transformer.make_layer_stack.layers = [
    @fixup_layers.SelfAttentionFixup,
    @fixup_layers.EncDecAttentionFixup,
    @fixup_layers.DenseReluDenseFixup,
]

# For the AT5 default experiment, num_layers is configured in
# third_party/py/t5/models/gin/models/bi_bert_base.gin
# The num_layers macro is also defined here (12) and bound to make_layer_stack
# without a scope, i.e. not restricted to encoder/ or decoder/.
# NOTE(review): if bi_bert_base.gin sets num_layers too, the effective value
# presumably depends on the order in which the gin files are parsed -- confirm.
num_layers = 12
transformer.make_layer_stack.num_layers = %num_layers

# d_ff is defined in bi_bert_base.gin.
# NOTE(review): %dropout_rate is not defined in this file either; presumably it
# is supplied by another gin file parsed alongside this one -- confirm.
fixup_layers.DenseReluDenseFixup.hidden_size = %d_ff
fixup_layers.DenseReluDenseFixup.dropout_rate = %dropout_rate

# Total block count handed to the Fixup layers:
#   num_layers * len(encoder_layers) + num_layers * len(decoder_layers)
#   = 12 * 2 + 12 * 3 = 60
# The value is a literal, so it must be updated by hand whenever num_layers or
# the encoder/decoder layer lists in this file change.
# The configurables are referenced with their fixup_layers. prefix, matching the
# other bindings in this file and avoiding ambiguity with same-named
# configurables from other modules.
# NOTE(review): EncDecAttentionFixup receives no num_blocks binding here --
# confirm whether it needs one.
num_blocks = 60
fixup_layers.SelfAttentionFixup.num_blocks = %num_blocks
fixup_layers.DenseReluDenseFixup.num_blocks = %num_blocks

# Fixup is more stable and performs better with tf.float32 than with bfloat16,
# so activation_dtype is bound to "float32" here.
utils.get_variable_dtype.activation_dtype = "float32"

# Initializer settings, grouped per Fixup layer class.
#   default_init:      one of ["he", "glorot"]; "he" is much more stable.
#   init_distribution: one of ["uniform", "truncated_normal"].
fixup_layers.SelfAttentionFixup.default_init = "he"
fixup_layers.SelfAttentionFixup.init_distribution = "truncated_normal"

fixup_layers.DenseReluDenseFixup.default_init = "he"
fixup_layers.DenseReluDenseFixup.init_distribution = "truncated_normal"
